#!/usr/bin/env python3

import os
import shlex
import subprocess
import sys

# Launcher options used for torchrun when the caller supplies none.
DEFAULT_TORCHRUN_ARGS = ('--standalone', '--nproc_per_node=gpu')


def detect_parallel_type():
    """Return the parallel backend to launch with: 'ds', 'ddp' or 'none'.

    'ds' is chosen when the ``deepspeed`` package imports successfully.
    Otherwise the number of CUDA devices reported by torch decides between
    'ddp' (more than one device) and 'none'.
    """
    try:
        import deepspeed  # noqa: F401  (only probing for availability)
        return 'ds'
    except Exception:
        # Any import-time failure is treated as "DeepSpeed unavailable".
        # Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        import torch

        return 'ddp' if torch.cuda.device_count() > 1 else 'none'


def build_command(parallel_type, run_file_name, launcher_args):
    """Build the argv list that starts ``run_file_name``.

    Args:
        parallel_type: 'ds', 'ddp', or anything else for a plain run.
        run_file_name: Path of the script to execute.
        launcher_args: Options placed between the launcher and the script.
            They are ignored for a plain ``python3`` run.

    Returns:
        The command as a list of tokens, suitable for ``subprocess``.
    """
    if parallel_type == 'ds':
        return ['deepspeed', *launcher_args, run_file_name]
    if parallel_type == 'ddp':
        torchrun_args = list(launcher_args) or list(DEFAULT_TORCHRUN_ARGS)
        return ['torchrun', *torchrun_args, run_file_name]
    return ['python3', run_file_name]


def format_command(command):
    """Render an argv list as a shell-quoted string for display."""
    return ' '.join(shlex.quote(token) for token in command)


def main(argv):
    """Launch the script named by ``argv[0]`` with the detected backend.

    Args:
        argv: Command-line arguments without the program name. The first
            item is the script to run; the rest are launcher options.

    Returns:
        The process exit status: 2 when no script was given, 127 when the
        launcher executable is missing, otherwise the child's status.
    """
    if not argv:
        program = os.path.basename(sys.argv[0]) if sys.argv else 'launcher'
        print(f'usage: {program} <run_file> [launcher_args ...]',
              file=sys.stderr)
        return 2

    run_file_name, *launcher_args = argv

    parallel_type = detect_parallel_type()
    # Exported so the launched script (which inherits the environment) can
    # read which backend was selected.
    os.environ['PARALLEL_TYPE'] = parallel_type

    command = build_command(parallel_type, run_file_name, launcher_args)
    # Flush so the message is not reordered after the child's own output
    # when stdout is a pipe or file.
    print(f'run command {format_command(command)}', flush=True)

    try:
        # An argv list (no shell) keeps paths with spaces or shell
        # metacharacters intact.
        return_code = subprocess.call(command)
    except FileNotFoundError:
        print(f'launcher not found: {command[0]}', file=sys.stderr)
        return 127

    # A negative code means the child was killed by a signal; map it to the
    # conventional 128 + signal number.
    return return_code if return_code >= 0 else 128 - return_code


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
